import os, pickle
import torch
from torch.utils.data import TensorDataset

# BIES prefixes wrapped around each segmentation tag.
wrap = ['B-', 'I-', 'E-', 'S-']
wgseg_tag = ['w']
label2id_dict = {'O': 0}
#pos_tags  = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'i', 'j', 'k', 'l', 'm', 'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z']

# Whitespace to emit after a token when reconstructing segmented text:
# ids 3 ('E-w') and 4 ('S-w') close a word, so they are followed by a space.
id2space = {0: '', 1: '', 2: '', 3: ' ', 4: ' '}

# Build the full label inventory: 'O' plus each BIES-wrapped tag.
for tag in wgseg_tag:
    for w in wrap:
        label2id_dict[w + tag] = len(label2id_dict)
id2label_dict = {v: k for k, v in label2id_dict.items()}
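# With the defaults above, the mappings resolve to:
#   label2id_dict == {'O': 0, 'B-w': 1, 'I-w': 2, 'E-w': 3, 'S-w': 4}
#   id2label_dict == {0: 'O', 1: 'B-w', 2: 'I-w', 3: 'E-w', 4: 'S-w'}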

def half2full(s):
    """Convert half-width (ASCII) characters to their full-width forms."""
    n = []
    if isinstance(s, bytes):
        s = s.decode('utf-8')
    for char in s:
        num = ord(char)
        if num == 32:                # space -> ideographic space
            num = 0x3000
        elif 0x21 <= num <= 0x7E:    # printable ASCII -> full-width block
            num += 0xFEE0
        n.append(chr(num))
    return ''.join(n)

def full2half(strs):
    """Convert full-width characters to their half-width (ASCII) forms."""
    n = []
    for char in strs:
        num = ord(char)
        if num == 0x3000:              # ideographic space -> space
            num = 32
        elif 0xFF01 <= num <= 0xFF5E:  # full-width block -> printable ASCII
            num -= 0xFEE0
        n.append(chr(num))
    return ''.join(n)
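# Round-trip example (the two conversions are inverses on this range):
#   full2half('Ａｂｃ　１２３') == 'Abc 123'
#   half2full('Abc 123') == 'Ａｂｃ　１２３'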



class Tokenize:
    """Character-level tokenizer backed by a BERT-style vocab file."""
    def __init__(self, dict_path):
        self.dict_path = dict_path
        self.token2id = {}
        self.id2token = {}
        self.unkid = 100   # assumed id of '[UNK]' in a standard BERT vocab
        self.dict_size = 0
        self.load_dict()

    def load_dict(self):
        # Keep the vocab-file line number as the token id; skip '##'
        # subword entries, since tokenization here is per character.
        with open(self.dict_path, 'r', encoding='utf-8') as f:
            for i, line in enumerate(f):
                token = line.strip('\n')
                if token.startswith('##'):
                    continue
                self.id2token[i] = token
                self.token2id[token] = i
        self.dict_size = len(self.token2id)

    def sentence2id(self, line):
        # Per-character lookup; unknown characters map to unkid.
        return [self.token2id.get(c, self.unkid) for c in line]

    def id2sentence(self, ids):
        # Unknown ids render as '' so join() never sees None.
        return ''.join(self.id2token.get(c, '') for c in ids)

    def convert_tokens_to_ids(self, line):
        return self.sentence2id(line)
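# Usage sketch (hypothetical vocab path; any BERT-style vocab.txt works,
# one token per line, with ids taken from line numbers):
#   tokenizer = Tokenize('vocab.txt')
#   ids = tokenizer.convert_tokens_to_ids('abc')   # unknown chars -> unkid
#   text = tokenizer.id2sentence(ids)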

class Examples:
    def __init__(self, tokens, label_ids):
        self.tokens = tokens
        self.label_ids = label_ids

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += f"tokens: {' '.join(self.tokens)}\n"
        s += f"labels: {' '.join(str(i) for i in self.label_ids)}\n"
        return s

class Features:
    def __init__(self, token_ids, input_mask, label_ids):
        self.token_ids = token_ids
        self.input_mask = input_mask
        self.label_ids = label_ids

    def __str__(self):
        return self.__repr__()

    def __repr__(self):
        s = ""
        s += f"token_ids: {' '.join(str(i) for i in self.token_ids)}\n"
        s += f"label_ids: {' '.join(str(i) for i in self.label_ids)}\n"
        s += f"input_mask: {' '.join(str(i) for i in self.input_mask)}\n"
        return s

def read_examples(input_file, mode=0):
    examples = []
    tokens = []
    label_ids = []
    errors = 0
    if mode == 0:
        # One token per line ('token\tseg_tag\tpos_tag\tentity_tag');
        # blank lines separate sentences.
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                if len(line.strip()) == 0:
                    if len(tokens) > 0:
                        examples.append(Examples(tokens, label_ids))
                        tokens = []
                        label_ids = []
                    continue
                try:
                    token, seg_tag, pos_tag, entity_tag = line.strip().split('\t')
                except ValueError:
                    errors += 1
                    print("Num errors: ", errors)
                    continue
                tokens.append(full2half(token.lower()))
                label_ids.append(label2id_dict[seg_tag])
        # Flush the final sentence if the file does not end with a blank line.
        if len(tokens) > 0:
            examples.append(Examples(tokens, label_ids))
    elif mode == 1:
        # One raw sentence per line; labels are placeholders ('O').
        with open(input_file, 'r', encoding='utf-8') as f:
            for line in f:
                tokens = list(full2half(line.lower().strip().replace(' ', '')))
                label_ids = [0] * len(tokens)
                examples.append(Examples(tokens, label_ids))

    return examples
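# Illustrative mode-0 input (hypothetical data; columns are tab-separated
# token, seg_tag, pos_tag, entity_tag, and only seg_tag is consumed here):
#
#   我    B-w    r    O
#   们    E-w    r    O
#   好    S-w    a    O
#   <blank line between sentences>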

def convert_example_to_features(examples, tokenizer, max_seq_length,
                                cls_token='[CLS]', sep_token='[SEP]', pad_token_id=0, txtmode=0):
    features = []
    PAD = '[PAD]'

    overlap = 8
    pad_label = [label2id_dict['O']]
    stride = max_seq_length - overlap * 2

    tokens = []
    token_ids = []
    label_ids = []
    examples_rsp = []
    # Reshape: concatenate all examples into one stream, with `overlap`
    # placeholder positions padded onto the head.
    tokens.extend([PAD] * overlap)
    token_ids.extend([0] * overlap)
    label_ids.extend([0] * overlap)
    for example in examples:
        stokens = [cls_token] + example.tokens + [sep_token]
        stoken_ids = tokenizer.convert_tokens_to_ids(stokens)
        slabel_ids = pad_label + example.label_ids + pad_label

        tokens.extend(stokens)
        token_ids.extend(stoken_ids)
        label_ids.extend(slabel_ids)
    # Mirror the head padding at the tail.
    tokens.extend([PAD] * overlap)
    token_ids.extend([0] * overlap)
    label_ids.extend([0] * overlap)
    assert len(token_ids) == len(label_ids)
    length = len(token_ids)
    start = 0

    # Slide a max_seq_length window over the stream in `stride` steps,
    # snapping each cut to a word boundary so no word is split.
    while start < length:
        end = start + max_seq_length
        # Shrink `end` until the preceding label closes a unit (E-, S-, or O).
        while start < end < length and id2label_dict[label_ids[end - 1]][0] not in 'ESO':
            end -= 1

        toks = tokens[start:end]
        ids = token_ids[start:end]
        labs = label_ids[start:end]
        masks = [1] * len(ids)

        pad_len = max_seq_length - len(ids)

        toks = toks + [PAD] * pad_len
        ids = ids + [pad_token_id] * pad_len
        masks = masks + [0] * pad_len
        labs = labs + pad_label * pad_len
        examples_rsp.append(Examples(toks, labs))

        features.append(Features(token_ids=ids, input_mask=masks, label_ids=labs))

        if start == 0:
            print(f'examples 0 show:\n{examples_rsp[0]}')
        start += stride
        # Advance `start` to the next label that opens a unit (B-, S-, or O).
        while start < length and id2label_dict[label_ids[start]][0] not in 'BSO':
            start += 1

    return examples_rsp, features
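# Window arithmetic with the defaults: max_seq_length=160 and overlap=8 give
# stride = 160 - 2*8 = 144, so consecutive windows share up to 16 positions
# before the boundary snapping above shifts the exact cut points.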

def read_features(input_file, max_seq_length=160, tokenizer=None, cls_token='[CLS]', sep_token='[SEP]', pad_token_id=0, txtmode=0, dump=True):
    cached_features_file = input_file + f'_{max_seq_length}.bin'
    if dump and os.path.exists(cached_features_file):
        # Reuse the cached examples/features for this sequence length.
        with open(cached_features_file, 'rb') as f:
            examples, features = pickle.load(f)
    else:
        raw_examples = read_examples(input_file, txtmode)
        examples, features = convert_example_to_features(raw_examples, tokenizer, max_seq_length,
                                                         cls_token, sep_token, pad_token_id, txtmode)
        if dump:
            with open(cached_features_file, 'wb') as f:
                pickle.dump([examples, features], f)

    all_token_ids = torch.tensor([f.token_ids for f in features], dtype=torch.long)
    all_input_mask = torch.tensor([f.input_mask for f in features], dtype=torch.long)
    all_label_ids = torch.tensor([f.label_ids for f in features], dtype=torch.long)

    dataset = TensorDataset(all_token_ids, all_input_mask, all_label_ids)

    return examples, dataset
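# Usage sketch (hypothetical file names; DataLoader is standard PyTorch):
#   from torch.utils.data import DataLoader
#   tokenizer = Tokenize('vocab.txt')
#   examples, dataset = read_features('train.txt', max_seq_length=160,
#                                     tokenizer=tokenizer)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True)
#   for token_ids, input_mask, label_ids in loader:
#       ...  # feed the batch to the model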




if __name__ == '__main__':
    pass